import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class DQN(nn.Module):
    """Three-layer MLP (two ReLU hidden layers) producing ``num_outputs`` values.

    Args:
        num_inputs: size of the last dimension of the input.
        block_hidden_dim: width of both hidden layers.
        num_outputs: size of the output layer.
    """

    def __init__(self, num_inputs, block_hidden_dim, num_outputs):
        super(DQN, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.block_hidden_dim = block_hidden_dim
        self.fc1 = nn.Linear(self.num_inputs, self.block_hidden_dim)
        self.fc2 = nn.Linear(self.block_hidden_dim, self.block_hidden_dim)
        # Output width follows the num_outputs argument instead of a fixed 3,
        # so the constructor argument is actually honoured.
        self.fc3 = nn.Linear(self.block_hidden_dim, self.num_outputs)

    def forward(self, state_action):
        """Map ``state_action`` (last dim ``num_inputs``) to a tensor whose last dim is ``num_outputs``."""
        hidden = F.relu(self.fc1(state_action))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class Spline_DQN(nn.Module):
    """Per-action monotone rational-quadratic spline evaluated at fixed quantile levels.

    One ReLU hidden layer feeds a linear head that emits, for each of
    ``num_outputs`` actions, ``K`` bin widths, ``K`` bin heights and ``K - 1``
    interior knot derivatives. These parameters define a monotone spline on
    [0, 1], which is evaluated at ``num_tau`` fixed levels ``self.tau``. The
    levels are the midpoints of ``num_tau`` equal-width intervals of [0, 1].

    Args:
        num_inputs: size of the last dimension of the input.
        num_outputs: number of actions (one spline per action).
        num_support: number of spline bins ``K``.
        num_tau: number of quantile levels evaluated per action.
        block_hidden_dim: width of the hidden layer.
        device: device on which the quantile levels are created.
    """

    def __init__(self, num_inputs, num_outputs, num_support, num_tau, block_hidden_dim, device):
        super(Spline_DQN, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.block_hidden_dim = block_hidden_dim
        self.K = num_support
        self.min_bin_width = 1e-3
        self.min_bin_height = 1e-3
        self.min_derivative = 1e-3
        self.num_tau = num_tau
        # Quantile levels: midpoints of num_tau equal-width intervals of [0, 1].
        edges = np.linspace(0.0, 1.0, self.num_tau + 1)
        tau = (torch.Tensor(edges[:-1] + edges[1:]) / 2).to(device)
        # A non-persistent buffer follows module.to()/.cuda()/.double() together
        # with the parameters, but does not add a key to state_dict().
        self.register_buffer('tau', tau, persistent=False)
        self.fc1 = nn.Linear(self.num_inputs, self.block_hidden_dim)
        self.fc2 = nn.Linear(self.block_hidden_dim, self.num_outputs * (3 * self.K - 1))
        # Scale/shift heads. forward() does not use them; they are kept so the
        # module's parameters and state_dict() keys stay unchanged.
        self.alpha = nn.Linear(self.block_hidden_dim, self.num_outputs)
        self.beta = nn.Linear(self.block_hidden_dim, self.num_outputs)

    def forward(self, state_action):
        """Evaluate each action's spline at the quantile levels ``self.tau``.

        Args:
            state_action: input of shape ``(batch, num_inputs)``. A
                ``(batch, 1, num_inputs)`` input is also accepted; the singleton
                dim is squeezed after the first layer.

        Returns:
            Tensor of shape ``(batch, num_outputs, num_tau)``. Values are
            non-decreasing along the last dimension and lie in
            ``[0, sum of bin heights]``, which is approximately ``[0, 1]``.
        """
        batch_size = state_action.size(0)
        x = F.relu(self.fc1(state_action))
        if x.dim() == 3:
            x = x.squeeze(1)

        # Raw spline parameters per action: K widths, K heights, K - 1 interior derivatives.
        spline_param = self.fc2(x).view(batch_size, self.num_outputs, 3 * self.K - 1)
        W, H, D = torch.split(spline_param, self.K, dim=2)

        # Softmax makes widths/heights positive and sum to 1. The affine map then
        # enforces a minimum bin size while keeping the total at 1.
        W = self.min_bin_width + (1 - self.min_bin_width * self.K) * torch.softmax(W, dim=2)
        H = self.min_bin_height + (1 - self.min_bin_height * self.K) * torch.softmax(H, dim=2)

        # Strictly positive derivatives at the K - 1 interior knots, padded to K + 1
        # knots with a fixed boundary value.
        # NOTE(review): `constant` is softplus^{-1}(1 - 1e-3) (about 0.54). It is
        # presumably meant to be set before the softplus, which would give a
        # boundary slope of 1. Here it is written after the softplus, so the
        # boundary slope is about 0.54. Confirm this is intended.
        D = self.min_derivative + F.softplus(D)
        D = F.pad(D, pad=(1, 1))  # (batch, num_outputs, K + 1)
        constant = np.log(np.exp(1 - 1e-3) - 1)
        D[..., 0] = constant
        D[..., -1] = constant

        # Knot x-positions 0 = x_0 < ... < x_K = 1. The last one is forced to
        # exactly 1 to absorb rounding.
        cumwidths = F.pad(torch.cumsum(W, dim=-1), pad=(1, 0), mode='constant', value=0.0)
        cumwidths[..., -1] = 1.0  # (batch, num_outputs, K + 1)
        widths = cumwidths[..., 1:] - cumwidths[..., :-1]  # (batch, num_outputs, K)

        # Knot y-positions 0 = y_0 < ... < y_K (approximately 1).
        cumheights = F.pad(torch.cumsum(H, dim=-1), pad=(1, 0), mode='constant', value=0.0)
        heights = cumheights[..., 1:] - cumheights[..., :-1]

        tau = self.tau.expand((batch_size, self.num_outputs, self.num_tau))

        # Bin index of every tau: (batch, num_outputs, num_tau).
        cumwidths_expand = cumwidths.unsqueeze(dim=2).expand(-1, -1, self.num_tau, -1)
        bin_idx = self.searchsorted_(cumwidths_expand, tau)

        input_cumwidths = cumwidths.gather(-1, bin_idx)  # x_i
        input_bin_widths = widths.gather(-1, bin_idx)  # x_{i+1} - x_i
        input_cumheights = cumheights.gather(-1, bin_idx)  # y_i
        input_heights = heights.gather(-1, bin_idx)  # y_{i+1} - y_i
        input_delta = (heights / widths).gather(-1, bin_idx)  # bin slope s_i
        input_derivatives = D.gather(-1, bin_idx)  # d_i
        input_derivatives_plus_one = D[..., 1:].gather(-1, bin_idx)  # d_{i+1}

        # Rational-quadratic spline evaluated at each tau, with
        # theta = (tau - x_i) / (x_{i+1} - x_i) in [0, 1).
        theta = (tau - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)
        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        return input_cumheights + numerator / denominator

    def searchsorted_(self, bin_locations, inputs):
        """Return the bin index of each element of ``inputs``.

        The index is (number of entries in the last dim of ``bin_locations``
        that are <= the input) - 1. Inputs below the first location give -1.
        """
        return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


class Spline_DQN_Single(nn.Module):
    def __init__(self, num_inputs, num_support, num_tau, device):
        '''
        Single-output monotone rational-quadratic spline evaluated at fixed quantile levels.

        num_inputs: dim of input data (state dim + action dim)
        num_support: number of knots / spline bins K
        num_tau: number of quantiles for quantile regression
        device: device on which the quantile-level tensor is created
        '''
        super(Spline_DQN_Single, self).__init__()
        self.num_inputs = num_inputs
        self.K = num_support
        self.min_bin_width = 1e-3
        self.min_bin_height = 1e-3
        self.min_derivative = 1e-3
        self.num_tau = num_tau
        # Quantile levels (numpy array): midpoints of num_tau equal-width intervals of [0, 1].
        intvl = np.linspace(0.0, 1.0, self.num_tau + 1)[1:]
        self.tau = (np.linspace(0.0, 1.0, self.num_tau + 1)[:-1] + intvl) / 2
        # Tensor copy of self.tau, built once instead of on every forward().
        # As a non-persistent buffer it follows module.to()/.cuda() and adds no
        # state_dict() key.
        self.register_buffer('_tau_tensor', torch.Tensor(self.tau).to(device), persistent=False)

        self.fc1 = nn.Linear(self.num_inputs, 128)
        self.fc2 = nn.Linear(128, (3 * self.K - 1))  # K widths + K heights + (K - 1) derivatives

        self.device = device

        # No output scale/shift layers: if the quantile values are interpreted as
        # probabilities, scaling is in principle unnecessary.

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, state):
        """Evaluate the spline at the quantile levels ``self.tau``.

        Args:
            state: input of shape ``(batch, num_inputs)``. A
                ``(batch, 1, num_inputs)`` input is also accepted; the singleton
                dim is squeezed after the first layer.

        Returns:
            Tensor of shape ``(batch, num_tau)``. Values are non-decreasing along
            the last dimension and lie in ``[0, sum of bin heights]``, which is
            approximately ``[0, 1]``.
        """
        batch_size = state.size(0)
        x = F.relu(self.fc1(state))
        if x.dim() == 3:
            x = x.squeeze(1)

        # Raw spline parameters: K widths, K heights, K - 1 interior derivatives.
        spline_param = self.fc2(x).view(batch_size, 3 * self.K - 1)
        W, H, D = torch.split(spline_param, self.K, dim=1)

        # Softmax makes widths/heights positive and sum to 1. The affine map then
        # enforces a minimum bin size while keeping the total at 1.
        W = self.min_bin_width + (1 - self.min_bin_width * self.K) * torch.softmax(W, dim=1)
        H = self.min_bin_height + (1 - self.min_bin_height * self.K) * torch.softmax(H, dim=1)

        # Strictly positive interior derivatives, padded to K + 1 knots with a
        # fixed boundary value.
        # NOTE(review): `constant` is softplus^{-1}(1 - 1e-3) (about 0.54), written
        # after the softplus, so the boundary slope is about 0.54 rather than 1.
        # Confirm this is intended.
        D = self.min_derivative + F.softplus(D)
        D = F.pad(D, pad=(1, 1))  # (batch, K + 1)
        constant = np.log(np.exp(1 - 1e-3) - 1)
        D[..., 0] = constant
        D[..., -1] = constant

        # Knot x-positions (tau axis) 0 = x_0 < ... < x_K = 1. The last one is
        # forced to exactly 1 to absorb rounding.
        cumwidths = F.pad(torch.cumsum(W, dim=-1), pad=(1, 0), mode='constant', value=0.0)
        cumwidths[..., -1] = 1.0
        widths = cumwidths[..., 1:] - cumwidths[..., :-1]  # (batch, K)

        # Knot y-positions (quantile values) 0 = y_0 < ... < y_K (approximately 1).
        cumheights = F.pad(torch.cumsum(H, dim=-1), pad=(1, 0), mode='constant', value=0.0)
        heights = cumheights[..., 1:] - cumheights[..., :-1]

        tau = self._tau_tensor.expand((batch_size, self.num_tau))

        # Bin index of every tau: (batch, num_tau).
        cumwidths_expand = cumwidths.unsqueeze(dim=1).expand(-1, self.num_tau, -1)  # (batch, num_tau, K + 1)
        bin_idx = self.searchsorted_(cumwidths_expand, tau)

        input_cumwidths = cumwidths.gather(-1, bin_idx)  # x_i
        input_bin_widths = widths.gather(-1, bin_idx)  # x_{i+1} - x_i
        input_cumheights = cumheights.gather(-1, bin_idx)  # y_i
        input_heights = heights.gather(-1, bin_idx)  # y_{i+1} - y_i
        input_delta = (heights / widths).gather(-1, bin_idx)  # bin slope s_i
        input_derivatives = D.gather(-1, bin_idx)  # d_i
        input_derivatives_plus_one = D[..., 1:].gather(-1, bin_idx)  # d_{i+1}

        # Rational-quadratic spline evaluated at each tau, with
        # theta = (tau - x_i) / (x_{i+1} - x_i) in [0, 1).
        theta = (tau - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)
        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        return input_cumheights + numerator / denominator

    def searchsorted_(self, bin_locations, inputs):
        """Return the bin index of each element of ``inputs``.

        The index is (number of entries in the last dim of ``bin_locations``
        that are <= the input) - 1. Inputs below the first location give -1.
        """
        return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1